/*
    Copyright 2006-2012 Patrik Jonsson, sunrise@familjenjonsson.org

    This file is part of Sunrise.

    Sunrise is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 3 of the License, or
    (at your option) any later version.

    Sunrise is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with Sunrise.  If not, see <http://www.gnu.org/licenses/>.

*/

/** \file 

    Implementation of the mpi_master class. 
*/

#ifndef __mpi_master_impl__
#define __mpi_master_impl__

#include "mpi_master.h"

#ifdef HAVE_MPI

#include <boost/mpi/communicator.hpp>
#include <boost/mpi/nonblocking.hpp>
#include <boost/mpi/collectives.hpp>
#include <boost/mpi/skeleton_and_content.hpp>
#include <boost/lexical_cast.hpp>
#include <tbb/tick_count.h>
#include <tbb/atomic.h>
#include <time.h>
#include <numeric>
#include <functional>
#include "hpm.h"
#include "mpi_util.h"
#include "mcrx-debug.h"

// Names of the hpm performance-monitoring stages the master thread
// cycles through; the indices must match the stage identifiers used
// with hpm::enter_stage below (Misc, Handshake, Receive, ...).
const int n_stages = 6;
const char* const stage_names[n_stages] = {"Misc", "Handshake", "Receive", "Receive_test", "Location", "Location_test"};

/** Constructor (MPI build). Stores references to the xfer object and
    the shared ray counter, registers this object as the xfer object's
    master, posts all persistent MPI receive requests, and initializes
    the hpm performance-monitoring stages.
    \param xx   the xfer object this master services
    \param nt   number of worker threads on this task
    \param nr   shared atomic count of rays completed so far
    \param nrd  number of rays desired before we are finished */
template<typename xfer_type>
mcrx::mpi_master<xfer_type>::mpi_master (T_xfer& xx, int nt, 
					 tbb::atomic<long>& nr,
					 long nrd)
  : x(xx), n_threads_(nt), n_rays_(nr), n_rays_desired_(nrd),
    we_finished_(false), sent_finish_(false), 
    all_finished_(false), n_resp_(0),
    n_received_(0), i_hpmdump_(0),
    recv_enabled_(false)
{
  // n_sent_ is assigned here rather than in the initializer list;
  // presumably it is an atomic counter without a suitable
  // constructor -- TODO confirm against the class declaration.
  n_sent_=0;
  x.set_master(this);

  // posts the persistent receives for rays, handshakes, location
  // requests and finish messages
  setup_requests();

  // register the performance-monitoring stage names for this thread
  hpm::thread_init(nt, n_stages);
  for(int i=0; i<n_stages; ++i)
    hpm::set_stage_name(i, stage_names[i]);
}

/** Destructor must cancel and complete all outstanding MPI requests
    otherwise these ghost requests will intercept the first messages
    sent in the next stage. This needs to make sure ALL requests
    created in the constructor are explicitly cancelled, because for
    trivial shootings, we don't actually do any communication so none
    of them are ever completed. */
template<typename xfer_type>
mcrx::mpi_master<xfer_type>::~mpi_master ()
{
  using namespace boost::mpi;

  // Receive requests are permanently pending (they are reposted as
  // soon as they complete), so cancel + wait closes each one. The
  // asserts verify that nothing completed without being processed
  // and reposted.
  for(int i=0; i<location_reqs_.size(); i++) {
    assert(!location_reqs_[i].test());
    location_reqs_[i].cancel();
    location_reqs_[i].wait();
  }
  // Send-request slots are optionals; only slots that were actually
  // used (is_initialized) hold a request that needs closing.
  for(int i=0; i<location_send_reqs_.size(); i++) {
    if(location_send_reqs_[i].is_initialized()) {
      location_send_reqs_[i].get().cancel();
      location_send_reqs_[i].get().wait();
    }
  }
  for(int i=0; i<ray_reqs_.size(); i++) {
    assert(!ray_reqs_[i].test());
    ray_reqs_[i].cancel();
    ray_reqs_[i].wait();
  }
  for(int i=0; i<handshake_reqs_.size(); i++) {
    assert(!handshake_reqs_[i].test());
    handshake_reqs_[i].cancel();
    handshake_reqs_[i].wait();
  }
  for(int i=0; i<handshake_send_reqs_.size(); i++) {
    if(handshake_send_reqs_[i].is_initialized()) {
      handshake_send_reqs_[i].get().cancel();
      handshake_send_reqs_[i].get().wait();
    }
  }

  // the master also needs to close the master_finish_reqs_
  for(int i=0; i<master_finish_reqs_.size(); ++i) {
    master_finish_reqs_[i].cancel();
    master_finish_reqs_[i].wait();
  }

  // every task holds a pending receive for the master's finish
  // message (and possibly a finish send of its own)
  finish_req_.cancel();
  finish_req_.wait();
  if(finish_send_req_.is_initialized()) {
    finish_send_req_.get().cancel();
    finish_send_req_.get().wait();
  }
    
  DEBUG(1,printf("Master thread %d exiting\n", task()););
}

/* Initialize the structure of the ray_data entries by exchanging the
   skeleta of the objects. This requires us to actually *have* a valid
   object, which we get from the xfer object.

   The exchange pairs tasks up hypercube-style: in pass ngrp, task r
   talks to task r^ngrp, with the lower rank receiving first so the
   blocking send/recv pair cannot deadlock. When n_tasks() is not a
   power of two, the XOR can produce a partner rank that does not
   exist; such passes are skipped (previously this case would have
   addressed a nonexistent rank). */
template<typename xfer_type>
void mcrx::mpi_master<xfer_type>::init_ray_data()
{
  using namespace boost::mpi;

  ray_recv_data_.resize(n_tasks());

  // obtain a valid queue item whose skeleton we can ship around
  T_queue_item item;
  x.init_queue_item(item);

  // Now that we have an item, we do a blocking exchange of the skeleta.
  // PTask is the number of bits needed to cover n_tasks().
  int PTask;
  for(PTask = 0; n_tasks() > (1 << PTask); PTask++);
  for(int ngrp = 1; ngrp < (1 << PTask); ngrp++) {
    const int recvTask = world.rank() ^ ngrp;
    // partner does not exist if n_tasks() is not a power of two
    if(recvTask >= n_tasks())
      continue;
    if(world.rank()<recvTask) {
      world.recv(recvTask, ray_tag, skeleton(ray_recv_data_[recvTask]));
      world.send(recvTask, ray_tag, skeleton(item));
    }
    else {
      world.send(recvTask, ray_tag, skeleton(item));
      world.recv(recvTask, ray_tag, skeleton(ray_recv_data_[recvTask]));
    }
  }
}


/** Set up incoming requests. Posts the persistent nonblocking
    receives for rays, handshake messages, per-thread location
    requests, and the finish message, after sizing all buffers and
    exchanging the ray-data skeleta. Called from the constructor. */
template<typename xfer_type>
void mcrx::mpi_master<xfer_type>::setup_requests()
{
  using namespace boost::mpi;
  // this barrier is to make sure we don't start posting sends until
  // we know that other tasks have closed the requests from an earlier
  // stage.
  mpi_barrier();

  // status buffers written by test_some on the corresponding requests
  location_stata_.resize((n_tasks()-1)*n_threads_);
  req_stata_.resize(n_tasks()-1);
  hskreq_stata_.resize(n_tasks()-1);

  // the location data is indexed by task and thread, so we allocate
  // Ntask of them to make it simple to calculate index
  location_data_.resize(n_tasks()*n_threads_);
  location_send_data_.resize(n_tasks()*n_threads_);
  location_send_reqs_.resize(n_tasks()*n_threads_);

  // handshake flags: nonzero means the corresponding task is willing
  // to receive rays. We start with only ourselves enabled; sending is
  // allowed only once ALL tasks are enabled.
  handshake_data_.assign(n_tasks(), 0);
  handshake_data_[task()]=1;
  handshake_send_data_.resize(n_tasks());
  handshake_send_reqs_.resize(n_tasks());
  // explicitly qualified: previously the unqualified call compiled
  // only through argument-dependent lookup on the vector iterators
  send_enabled_ = std::accumulate(handshake_data_.begin(),
				  handshake_data_.end(),
				  true, std::logical_and<int>());

  init_ray_data();

  for(int i=0; i<n_tasks(); ++i) {
    // task 0 collects one finish message from every task, itself included
    if(task()==0)
      master_finish_reqs_.push_back(world.irecv(i, finish_tag));

    if(i!=task()) {
      // persistent receives for incoming rays and handshakes
      ray_reqs_.push_back(world.irecv(i, ray_tag, 
				      get_content(ray_recv_data_[i])));
      handshake_reqs_.push_back(world.irecv(i, handshake_tag, 
					    get_content(handshake_data_[i])));
      
      // one location-request receive per remote worker thread,
      // distinguished by tag
      for(int t=0; t<n_threads_; ++t) {
	location_reqs_.push_back(world.irecv(i, request_thread2tag(t), 
					     get_content(location_data_[i*n_threads_+t])));
      }
    }
  }

  // every task listens for the "all finished" message from task 0
  finish_req_ = world.irecv(0, master_finish_tag);
  world.barrier();
}


/** Process handshake messages. Handshakes implement flow control:
    each task broadcasts 1 ("start") when it is willing to receive
    rays and 0 ("stop") when its queue is too deep. We may send rays
    only while ALL tasks are enabled (send_enabled_).
    (The unused local `processed` has been removed.) */
template<typename xfer_type>
void mcrx::mpi_master<xfer_type>::process_handshakes()
{
  using namespace std; using namespace boost::mpi;
  const int curstage = hpm::current_stage();
  hpm::enter_stage(Handshake);

  // test if any of the handshake requests have completed.  completed.first
  // points to the end of the written req_stata. completed.second
  // points to the first completed request, after partitioning of the
  // vector.
  std::pair<vector<status>::iterator, vector<request>::iterator> completed =
    boost::mpi::test_some(handshake_reqs_.begin(), handshake_reqs_.end(),
			  hskreq_stata_.begin());

  vector<status>::iterator s = hskreq_stata_.begin();
  // test_some writes one status per completed request
  assert(completed.first-s == handshake_reqs_.end()-completed.second);
  for(vector<request>::iterator r=completed.second; r!=handshake_reqs_.end();
      ++r, ++s) {
    assert(s!=completed.first);

    const int source_task = s->source();
    const int tag = s->tag();
    assert(tag==handshake_tag);

    assert(source_task!=task());
    DEBUG(0,printf("Task %d received handshake message from task %d: %s\n",task(), source_task, handshake_data_[source_task] ? "start" : "stop");cout.flush(););

    if(handshake_data_[source_task]) {
      // if we got an enable, see if all tasks are enabled
      send_enabled_ = accumulate(handshake_data_.begin(), handshake_data_.end(),
				 true, std::logical_and<int>());
    }
    else {
      // if we got a disable, we know to disable
      send_enabled_=false;
    }
    
    // repost the request so the next handshake from this task is caught
    *r = world.irecv(source_task, handshake_tag, 
		     get_content(handshake_data_[source_task]));
  }

  hpm::enter_stage(curstage);
}


/** Broadcast a handshake message to every other task, telling them
    whether they may (enable) or may not (disable) send rays to us.
    Also records our own receive state in recv_enabled_. */
template<typename xfer_type>
void mcrx::mpi_master<xfer_type>::send_handshake(bool enable)
{
  using namespace boost::mpi;

  DEBUG(0,printf("Master task %d sending handshake: %s\n",task(), enable ? "start" : "stop");cout.flush(););

  const int flag = enable ? 1 : 0;
  for(int dest=0; dest<n_tasks(); ++dest) {
    if(dest==task())
      continue;

    // make sure any previous handshake send to this task has
    // completed before reusing its buffer and request slot
    if(handshake_send_reqs_[dest].is_initialized())
      handshake_send_reqs_[dest].get().wait();

    handshake_send_data_[dest] = flag;
    handshake_send_reqs_[dest] =
      world.isend(dest, handshake_tag, get_content(handshake_send_data_[dest]));
  }
  recv_enabled_ = enable;
}


/** Process incoming ray requests. Drains the rays other tasks have
    sent us, pushing a deep copy of each onto the local work queue and
    reposting the receive.
    \returns true if at least one ray was received. */
template<typename xfer_type>
bool mcrx::mpi_master<xfer_type>::process_incoming_rays()
{
  using namespace std; using namespace boost::mpi;
  const int curstage = hpm::current_stage();
  hpm::enter_stage(Receive_test);
  bool processed=false;

  std::pair<vector<status>::iterator, vector<request>::iterator> completed;
  vector<status>::iterator s;
  // iteration counter bounding the drain loop (see cap below)
  int i=0;

  // test if any of the ray requests have completed.  completed.first
  // points to the end of the written ray_stata. completed.second
  // points to the first completed request, after partitioning of the
  // vector.
  while((completed = boost::mpi::test_some(ray_reqs_.begin(), ray_reqs_.end(),
					   req_stata_.begin())).second !=
	ray_reqs_.end()) {

    hpm::enter_stage(Receive);

    // loop over the completed requests, if any, and repost
    s = req_stata_.begin();
    for(vector<request>::iterator r=completed.second; r!=ray_reqs_.end();
	++r, ++s) {
      const int source_task = s->source();
      const int tag = s->tag();
      assert(tag==ray_tag);

      // init the tracker and push the ray onto the local queue. Note
      // that we must deep-copy the item, otherwise it will keep
      // referring to the receive buffer.
      T_queue_item item(independent_copy(ray_recv_data_[source_task]));
      x.local_queue()->push(item);

      DEBUG(1,printf("Master task %d received %s ray from task %d, queue depth %ld\n",task(), action_strings[ray_recv_data_[source_task].action_], source_task, x.local_queue()->size()));

      // repost the request
      *r = world.irecv(source_task, ray_tag, 
		       get_content(ray_recv_data_[source_task]));

      processed=true;
      ++n_received_;
    }
    assert(s==completed.first);

    hpm::enter_stage(Receive_test);
    // cap the number of drain iterations so a steady stream of
    // incoming rays cannot starve the rest of the master loop
    ++i;
    if(i>10000) {
      break;
    }
  }

  hpm::enter_stage(curstage);
  return processed;
}


/** Callback from xfer thread, sending a ray directly from that
    thread. Each thread has a send slot. \todo do we need to worry
    about false sharing here? each thread will look at neighboring
    data.
    \param item    the queue item (ray) to ship
    \param thread  index of the calling worker thread; currently only
                   referenced by the commented-out debug output
    \param dest    destination MPI task rank */
template<typename xfer_type>
void mcrx::mpi_master<xfer_type>::thread_send_ray(const T_queue_item& item, 
						  int thread, int dest)
{
  using namespace std; using namespace boost::mpi;

  {
    hpm::scoped_stage current(T_xfer::Send_wait);

    // tdata_ holds this thread's private send buffer and request;
    // presumably a thread-specific pointer -- see the class
    // declaration to confirm
    if(tdata_.get())
      // wait until our previous message has completed.
      tdata_->send_req.wait();
    else
      // allocate the thread-specific ptr
      tdata_.reset(new thread_data);
  }

  // put the item in the send buffer and post the send (note that
  // this shallow-copies the arrays, which is fine because they
  // were independently copied when the queue_item was constructed.
  tdata_->send_data = item;
  tdata_->send_req =
    world.isend(dest, ray_tag, get_content(tdata_->send_data));
  ++n_sent_;

  //DEBUG(1,printf("Worker %d-%d direct sending %s ray to task %d\n",task(), thread, action_strings[item.action_], item.cell_.task()));
}


/** Process incoming location requests. Remote worker threads ask this
    task for grid data; each completed request is answered via
    x.grid().location_request() and the receive is reposted. Requests
    are matched to the originating thread through the message tag.
    \returns true if at least one request was serviced. */
template<typename xfer_type>
bool mcrx::mpi_master<xfer_type>::process_locations()
{
  using namespace std; using namespace boost::mpi;
  const int curstage = hpm::current_stage();
  hpm::enter_stage(Location_test);
  bool processed=false;

  // test if any of the requests have completed.  completed.first
  // points to the end of the written ray_stata. completed.second
  // points to the first completed request, after partitioning of the
  // vector.
  const std::pair<vector<status>::iterator,
		  vector<request>::iterator>
    completed = boost::mpi::test_some(location_reqs_.begin(), 
				      location_reqs_.end(),
				      location_stata_.begin());

  // loop over the completed receive requests, if any
  if(completed.second!=location_reqs_.end()) {
    hpm::enter_stage(Location);
    vector<status>::iterator s = location_stata_.begin();
    for(vector<request>::iterator r=completed.second; r!=location_reqs_.end(); ++r, ++s) {
      const int source_task = s->source();
      // the tag encodes which worker thread on the source task asked
      const int thread = request_tag2thread(s->tag());
      
      assert(thread>=0 && thread<n_threads_);
      DEBUG(1,printf("Master task %d responding to location request #%ld from worker %d-%d\n",task(), n_resp_, source_task, thread);cout.flush(););
      n_resp_++;
      
      // process the location request and respond
      const int data_index = source_task*n_threads_+thread;
      typename T_grid::T_location_response response = 
	x.grid().location_request(location_data_[data_index]);
      
      // wait for previous send request to complete, copy data into
      // send buffer and send it.
      if(location_send_reqs_[data_index].is_initialized())
	location_send_reqs_[data_index].get().wait();

      location_send_data_[data_index] = response;
      location_send_reqs_[data_index] =
	world.isend(source_task, response_thread2tag(thread), 
		    get_content(location_send_data_[data_index]));
      
      // repost the request receive
      *r = world.irecv(source_task, request_thread2tag(thread),
		       get_content(location_data_[data_index]));
      processed=true;
    }
    assert(s==completed.first);
  }
    
  hpm::enter_stage(curstage);
  return processed;
}

/** Test whether we are complete. This is surprisingly difficult to do
    in a nonblocking way, since rays may get shipped in after a task
    thinks it's done. Here's our strategy: When a task is done
    producing its rays and has an empty queue, it messages the master
    task. When the master has received NTask such messages, we know
    that *primary* production of rays is complete. However, we might
    still have rays ping-ponging between tasks, so to really know we
    need to do a synchronous reduction.
    \returns true once all tasks are globally done (workers idle and
    every sent ray accounted for on the receiving side). */
template<typename xfer_type>
bool mcrx::mpi_master<xfer_type>::test_finish()
{
  namespace mpi=boost::mpi;

  const int rank= task();


  // first test if we are done and should send finish msg to master
  if( !we_finished_ && (n_rays_>=n_rays_desired_) && 
      x.local_queue()->empty() && x.pixel_queue()->empty() ) {
    // send finish message to master (even the master sends one to
    // itself because that makes the logic easier)
    finish_send_req_ = world.isend(0, finish_tag);
    we_finished_ = true;
  }

  // if we are the master, see if we have received finish messages
  // from all the other tasks
  if(we_finished_ && !sent_finish_ && rank==0) {
    if(test_all(master_finish_reqs_.begin(), master_finish_reqs_.end())) {
      // we have. send finish messages to all of them. we put those
      // requests in the master_finish_reqs since those requests are
      // now all done
      assert(master_finish_reqs_.size()==n_tasks());
      for(int i=0; i<n_tasks(); ++i)
	master_finish_reqs_[i] = world.isend(i, master_finish_tag);
      sent_finish_=true;
    }
  }

  // if we have finished, test if we have received message from master
  if(we_finished_ && !all_finished_ && finish_req_.test() ) {
    all_finished_ = true;
    // we must repost the finish_req_ so it can be cancelled in the destructor
    finish_req_ = world.irecv(0, master_finish_tag);
  }

  // if all tasks are done with primary rays, we need to make sure
  // that all workers are idle. Unfortunately with the priority queue,
  // we can't have the workers block while waiting, so the best we can
  // do is test whether the queues are empty at this instant.
  if(all_finished_ && !finish_time_.is_initialized()) {
    // NOTE(review): a size of -n_threads_ presumably means all worker
    // threads are blocked waiting to pop from an empty queue -- TODO
    // confirm the queue's size() semantics.
    const ptrdiff_t sz = x.local_queue()->size();
    // these collectives are synchronous: every task enters them only
    // after receiving the master's finish message
    const bool workers_done = 
      mpi::all_reduce(world, (sz== -n_threads_), std::logical_and<bool>());
    const long n_sent_tot = 
      mpi::all_reduce(world, long(n_sent_), std::plus<long>());
    const long n_recv_tot = 
      mpi::all_reduce(world, n_received_, std::plus<long>());

    // only truly done when no ray is still in flight between tasks
    if(workers_done && (n_sent_tot==n_recv_tot) ) {
      // we are done
      finish_time_ = tbb::tick_count::now();
      return true;
    }
  }
  return false;
}

/** Main loop of the master thread (MPI build). With more than one
    task, this repeatedly services incoming rays, location requests
    and handshakes, throttles senders when the local queue gets too
    deep, and periodically prints statistics, until test_finish()
    reports global completion. With a single task it just sleeps and
    polls for completion. Finally it terminates the worker threads and
    reduces the total ray count to task 0. */
template<typename xfer_type>
void mcrx::mpi_master<xfer_type>::run()
{
  if(n_tasks()>1)
    printf("Master task %d starting\n",task());

  bool production_complete = false;
  tbb::tick_count last_info = tbb::tick_count::now();
  const double info_interval = 10;
  min_idle_time_ = blitz::huge(double());
  int last_sent=0,last_received=0,last_loc=0,last_rays=0;

  // tell the other tasks we are ready to receive
  send_handshake(1);

  if(n_tasks()>1)
    hpm::thread_start(Misc);

  while(true) {

    if(n_tasks()==1) {
      // if we only have one task, all of this is unnecessary and we
      // just sleep and periodically wake up to see if we are done
      struct timespec sleepy_time;
      sleepy_time.tv_sec = 1;
      sleepy_time.tv_nsec = 0;
      nanosleep (&sleepy_time, 0);

      if( (n_rays_>=n_rays_desired_) && x.pixel_queue()->empty() )
	break;
      continue;
    }

    const ptrdiff_t n=x.local_queue()->size();
    const tbb::tick_count now = tbb::tick_count::now();

    // see if we should tell the other tasks to stop
    if(recv_enabled_ && (n>max_queue_depth_) ) {
      // send stop request
      send_handshake(0);
    }

    // service all message types, tracking whether anything came in so
    // we can measure the master's minimum idle time between actions
    process_handshakes();
    bool idle=true;
    idle = !process_incoming_rays() && idle;
    //idle = !process_outgoing_rays() && idle;
    idle = !process_locations() && idle;
    if(!idle) {
      const tbb::tick_count now(tbb::tick_count::now());
      const double idle_time=
	(now-last_action_time_).seconds();
      last_action_time_=now;
      if(idle_time<min_idle_time_)
	min_idle_time_=idle_time;
    }

    // see if we should tell the other tasks to restart sending. If we
    // are catching up, we don't send a restart until we have gone
    // through an iteration where no incoming messages were found. If
    // we stopped because the local queue was full, we restart if it's
    // down to a reasonable size.
    if (!recv_enabled_ && (n<0.5*max_queue_depth_) ) {
      send_handshake(1);
    }

    // n_rays_ is incremented by the worker threads, so the value we
    // poll can jump past n_rays_desired_; use >= (as test_finish
    // does) so the completion message cannot be skipped.
    if(!production_complete && (n_rays_>=n_rays_desired_) && 
       x.pixel_queue()->empty()) {
      production_complete=true;
      printf("Task %d: Production of primary rays complete.\n",task());
    }

    if(test_finish())
      break;

    /*
    if(finish_time_.is_initialized()) {
      const double time_to_finish = finish_buffer_time_ - 
	(tbb::tick_count::now()-finish_time_.get()).seconds();
      if(time_to_finish<0)
	break;
    }
    */

    // periodic progress/statistics report
    if((now-last_info).seconds()>info_interval) {
      last_info=now;
      const long disp_rays = 
	x.pixel_queue()->empty() ? long(n_rays_) : 
	long(x.pixel_queue()->unsafe_size());
      printf("Task %d:\tQueue depth: %ld n_rays: %ld min master idle: %g ms\n     \tsent %ld received %ld locations %ld handshake %d\n     \trays/s: %#5.2g sent/s %#5.2g received/s %#5.2g locations/s %#5.2g\n",task(),n,disp_rays,min_idle_time_*1e3,long(n_sent_), n_received_, n_resp_, int(recv_enabled_), (disp_rays-last_rays)/info_interval,(n_sent_-last_sent)/info_interval,(n_received_-last_received)/info_interval,(n_resp_-last_loc)/info_interval);
      last_rays= disp_rays;
      last_received=n_received_;
      last_sent=n_sent_;
      last_loc=n_resp_;

      // show which tasks currently accept rays from us (-1 = self)
      printf("Task %d send mask: ",task());
      for(int i=0; i<n_tasks(); ++i)
	printf("%d ",i==task() ? -1 : handshake_data_[i]);
      printf("\n");

      /*
      if(i_hpmdump_>5) {
	hpm::dump_data(string("Master thread ") +
		       boost::lexical_cast<string>(task()));
	i_hpmdump_ = 0;
      }
      else 
	i_hpmdump_++;
      */

      //n_sent_=0; n_received_=0; n_resp_=0;

      min_idle_time_=blitz::huge(double());
    }
  }

  terminate_workers();
  
  // add up number of rays shot to master task
  long n_rays_tot=0;
  boost::mpi::reduce(world, long(n_rays_), n_rays_tot, std::plus<long>(), 0);
  n_rays_ = n_rays_tot;

  if(n_tasks()>1)
    hpm::thread_stop(std::string("Master thread ") +
		     boost::lexical_cast<std::string>(task()));

}

#else // HAVE_MPI

/** Constructor for the single-task build (no MPI). Just records the
    references to the xfer object and run parameters and registers
    this object as the xfer object's master. */
template<typename xfer_type>
mcrx::mpi_master<xfer_type>::mpi_master (T_xfer& xx, int nt,
					 tbb::atomic<long>& nr,
					 long nrd) :
  x(xx),
  n_threads_(nt),
  n_rays_(nr),
  n_rays_desired_(nrd)
{
  x.set_master(this);
}

/** Main loop of the master thread in the single-task (no-MPI) build:
    sleep one second at a time, waking up to check whether the desired
    number of rays has been shot and the pixel queue is drained, then
    signal the workers to terminate. */
template<typename xfer_type>
void mcrx::mpi_master<xfer_type>::run()
{
  for(;;) {
    // nothing to communicate -- just nap and poll for completion
    struct timespec nap;
    nap.tv_sec = 1;
    nap.tv_nsec = 0;
    nanosleep(&nap, 0);

    const bool done =
      (n_rays_ >= n_rays_desired_) && x.pixel_queue()->empty();
    if(done)
      break;
  }

  terminate_workers();
}

#endif

/** Puts invalid items on the queue, which signals the workers to
    terminate: one poison-pill item per worker thread. */
template<typename xfer_type>
void mcrx::mpi_master<xfer_type>::terminate_workers()
{
  int remaining = n_threads_;
  while(remaining-- > 0)
    x.local_queue()->push(T_xfer::T_queue_item::make_invalid());
}


#endif
